LlamaIndexを完全に理解するチュートリアル その4:ListIndexで埋め込みベクトルを使用する方法

LlamaIndexを完全に理解するチュートリアル その4:ListIndexで埋め込みベクトルを使用する方法

ListIndexで埋め込みベクトルを使ってみよう。
Clock Icon2023.05.28

この記事は公開されてから1年以上経過しています。情報が古い可能性がありますので、ご注意ください。

こんちには。

データアナリティクス事業本部 インテグレーション部 機械学習チームの中村です。

「LlamaIndexを完全に理解するチュートリアル その4」では、GPTListIndexで埋め込みベクトルを使う方法を見ていきます。

本記事で使用する用語は以下のその1で説明していますので、そちらも参照ください。

LlamaIndexを完全に理解するチュートリアル
その1:処理の概念や流れを理解する基礎編(v0.7.9対応)
その2:テキスト分割のカスタマイズ
その3:CallbackManagerで内部動作の把握やデバッグを可能にする
その4:ListIndexで埋め込みベクトルを使用する方法

・本記事の内容はその1のv0.7.9版の記事を投稿後、v0.7.9で動作するように修正しています

本記事の内容

LlamaIndexのGPTListIndexは通常、埋め込みベクトルは使用せず全てのノードを使って処理をシマス。

ただしオプションとしては準備されており、クエリとノードの埋め込みの類似度を求め、使用するノードを決定することは可能となっています。

今回はそれを実現する設定方法を見ていきます。

環境準備

その1と同様の方法で準備します。

使用したバージョン情報は以下となります。

  • Python : 3.10.11
  • langchain : 0.0.234
  • llama-index : 0.7.9
  • openai : 0.27.8

サンプルコード

ベースのサンプルは以下とします。ノードの選択状況がわかりやすいよう、LlamaDebugHandlerをCallbackManagerに設定しておきます。

from llama_index import SimpleDirectoryReader
from llama_index import ListIndex
from llama_index import ServiceContext
from llama_index.callbacks import CallbackManager, LlamaDebugHandler

documents = SimpleDirectoryReader(input_dir="./data").load_data()

llama_debug_handler = LlamaDebugHandler()
callback_manager = CallbackManager([llama_debug_handler])
service_context = ServiceContext.from_defaults(callback_manager=callback_manager)

list_index = ListIndex.from_documents(documents
    , service_context=service_context)

query_engine = list_index.as_query_engine()

response = query_engine.query("機械学習に関するアップデートについて300字前後で要約してください。")

LlamaDebugHandlerの出力ログは以下の通りです。

**********
Trace: index_construction
    |_node_parsing ->  0.028296 seconds
      |_chunking ->  0.015003 seconds
      |_chunking ->  0.012298 seconds
**********
**********
Trace: query
    |_query ->  41.653571 seconds
      |_retrieve ->  0.0 seconds
      |_synthesize ->  41.653571 seconds
        |_llm ->  9.249877 seconds
        |_llm ->  12.309853 seconds
        |_llm ->  20.035586 seconds
**********

デフォルトのListIndexは全ノードを使っていることの確認

使用されるノードの選択状況は、LlamaDebugHandlerのRETRIEVEの結果またはresponse.source_nodesから把握することができます。

from llama_index.callbacks import CBEventType

node_list = llama_debug_handler.get_event_pairs(CBEventType.RETRIEVE)[0][1].payload["nodes"]
# node_list = response.source_nodes # こちらでも可
node_count = len(node_list)
print(f"{node_count=}")

for node in node_list:
    doc_id = node.node.id_
    print(f"{doc_id=}")
node_count=8
doc_id='ba03d238-41ad-4a3a-adfc-3d8ac6cdde48'
doc_id='c2cd2c7c-f6a1-4a3a-83d5-cb77e56c56b4'
doc_id='80ef0554-de26-4149-985d-cd3594b521e1'
doc_id='d8e911e9-e800-424d-9aba-c438e11c11a6'
doc_id='2db8f1cf-c766-47d6-88b6-13d71a23e678'
doc_id='0810e0ee-b1b5-4d00-9652-4cabbafab5bf'
doc_id='8b271e44-b4d8-4b9e-9aa3-a5c6ab88c3f9'
doc_id='4d27b006-891e-4fdf-aad6-db6f8237998b'

選ばれたノード数は8個となっています。ListIndexに含まれるdoc_idの情報は以下で取得できます。

for doc_id,v in list_index.storage_context.docstore.docs.items():
    print(f"{doc_id=}")
doc_id='ba03d238-41ad-4a3a-adfc-3d8ac6cdde48'
doc_id='c2cd2c7c-f6a1-4a3a-83d5-cb77e56c56b4'
doc_id='80ef0554-de26-4149-985d-cd3594b521e1'
doc_id='d8e911e9-e800-424d-9aba-c438e11c11a6'
doc_id='2db8f1cf-c766-47d6-88b6-13d71a23e678'
doc_id='0810e0ee-b1b5-4d00-9652-4cabbafab5bf'
doc_id='8b271e44-b4d8-4b9e-9aa3-a5c6ab88c3f9'
doc_id='4d27b006-891e-4fdf-aad6-db6f8237998b'

一致していることが分かり、現状はListIndexの全てのノードをRETRIEVEで選択していることが分かります。

埋め込みベクトルを使って選択する

ListIndexのデフォルト動作は以上ですが、設定によりノードを埋め込みベクトルの類似度で選択することが可能となります。

そのためには、retriever_modeをListRetrieverMode.EMBEDDINGに設定すればOKです。

similarity_top_kも3に設定し、クエリとの類似度が高い順にノードを3つを選択してみます。

from llama_index import SimpleDirectoryReader
from llama_index import ListIndex
from llama_index import ServiceContext
from llama_index.callbacks import CallbackManager, LlamaDebugHandler
from llama_index.indices.list.base import ListRetrieverMode

documents = SimpleDirectoryReader(input_dir="./data").load_data()

llama_debug_handler = LlamaDebugHandler()
callback_manager = CallbackManager([llama_debug_handler])
service_context = ServiceContext.from_defaults(callback_manager=callback_manager)

list_index = ListIndex.from_documents(documents
    , service_context=service_context)

query_engine = list_index.as_query_engine(
    retriever_mode=ListRetrieverMode.EMBEDDING
    , similarity_top_k=3
)

response = query_engine.query("機械学習に関するアップデートについて300字前後で要約してください。")

LlamaDebugHandlerの出力ログは以下の通りです。

**********
Trace: index_construction
    |_node_parsing ->  0.037637 seconds
      |_chunking ->  0.016638 seconds
      |_chunking ->  0.018999 seconds
**********
**********
Trace: query
    |_query ->  20.332703 seconds
      |_retrieve ->  9.652239 seconds
        |_embedding ->  0.462239 seconds
        |_embedding ->  0.205543 seconds
        |_embedding ->  0.24983 seconds
        |_embedding ->  7.059133 seconds
        |_embedding ->  0.237001 seconds
        |_embedding ->  0.362483 seconds
        |_embedding ->  0.542566 seconds
        |_embedding ->  0.218405 seconds
        |_embedding ->  0.306953 seconds
      |_synthesize ->  10.679464 seconds
        |_llm ->  10.652076 seconds
**********

RETRIEVE時にEmbeddingが動作していることが分かります。

使用されるノードの選択状況をみてみましょう。

from llama_index.callbacks import CBEventType

node_list = llama_debug_handler.get_event_pairs(CBEventType.RETRIEVE)[0][1].payload["nodes"]
# node_list = response.source_nodes # こちらでも可
node_count = len(node_list)
print(f"{node_count=}")

for node in node_list:
    doc_id = node.node.id_
    score = node.score
    print(f"{doc_id=}, {score=}")
node_count=3
doc_id='bb813454-6ded-4a7c-87b9-b287a2662dd9', score=0.8604478015232282
doc_id='e20291ca-d400-4a4e-a356-c5fd1ca2a973', score=0.841312029237466
doc_id='023656f1-693f-407e-8598-a3e17d0bcc62', score=0.8380146604712428

3つのノードがスコアの高い順に抽出されていることが分かります。

注意点:クエリの都度ノードの埋め込みベクトルを求めてしまう

ListRetrieverMode.EMBEDDINGの場合、求めた埋め込みベクトルはデータストアなどに保存しているわけではないため、

ノード抽出時にその都度埋め込みベクトルを再計算するコストが掛かってしまう点は注意が必要です。

データストアを使用するには、埋め込みベクトルをIndexStoreのノード保存するか、もしくは別のSimpleVectorIndexなどを使う方でも無難に実現できます。

今回はListIndexの範囲に収まる前者の方法を見ていきます。

対策:IndexStoreに埋め込みベクトルを含める方法

まずはベースとなるListIndexを作成します。

from llama_index import SimpleDirectoryReader
from llama_index import Document
from llama_index import GPTListIndex
from llama_index import ServiceContext
from llama_index.callbacks import CallbackManager, LlamaDebugHandler, CBEventType
from llama_index.indices.list.base import ListRetrieverMode

documents = SimpleDirectoryReader(input_dir="./data").load_data()

llama_debug_handler = LlamaDebugHandler()
callback_manager = CallbackManager([llama_debug_handler])
service_context = ServiceContext.from_defaults(callback_manager=callback_manager)

list_index = GPTListIndex.from_documents(documents
    , service_context=service_context)

そしてノードの一覧を取得して、埋め込みベクトルを求めて格納します。

# 埋め込みベクトルを計算
for doc_id, node in list_index.storage_context.docstore.docs.items():
    service_context.embed_model.queue_text_for_embedding(
        doc_id, node.text
    )
result_ids, result_embeddings = service_context.embed_model.get_queued_text_embeddings()

id_to_embed_map = {}
for new_id, text_embedding in zip(result_ids, result_embeddings):
    id_to_embed_map[new_id] = text_embedding

# ノードのembedding属性に埋め込みベクトルを格納
node_list = []
for doc_id, node in list_index.storage_context.docstore.docs.items():
    node.embedding = id_to_embed_map[doc_id]
    node_list.append(node)

# 修正したノードでインデックスを再構成
_ = list_index.build_index_from_nodes(node_list)

このようにしておけば、クエリ時に埋め込みベクトルをスキップすることができます。

from llama_index.indices.list.base import ListRetrieverMode

query_engine = list_index.as_query_engine(
    retriever_mode=ListRetrieverMode.EMBEDDING
    , similarity_top_k=3
)

response = query_engine.query("機械学習に関するアップデートについて300字前後で要約してください。")
**********
Trace: query
    |_query ->  9.887931 seconds
      |_retrieve ->  0.385121 seconds
        |_embedding ->  0.3596 seconds
      |_synthesize ->  9.50281 seconds
        |_llm ->  9.47453 seconds
**********

RETRIEVE時のEMBEDDING処理が1回だけ残っていますが、これはクエリ自体の埋め込みベクトルを求めているため、意図通り動いています。

まとめ

いかがでしたでしょうか。

本記事が、今後LlamaIndexをお使いになられる方の参考になれば幸いです。

Share this article

facebook logohatena logotwitter logo

© Classmethod, Inc. All rights reserved.